import sys
import math
import time
from multiprocessing import Pool, cpu_count


def load_ratings(filename):
    """
    OBJ: Load ratings from CSV file
    str -> dict, dict

    The first line is treated as a header and skipped. Every following
    non-blank line must start with ``user_id,movie_id,rating``; any further
    columns (e.g. a timestamp) are ignored.

    Returns (users, movies) where users[user_id][movie_id] == rating and
    movies[movie_id][user_id] == rating. If the same (user, movie) pair
    appears more than once, the last occurrence wins. An empty file yields
    two empty dicts.
    """
    users = {}
    movies = {}

    # Explicit encoding so parsing does not depend on the platform locale.
    with open(filename, 'r', encoding='utf-8') as f:
        # Skip header; the default stops an empty file from raising StopIteration.
        next(f, None)

        for line in f:
            line = line.strip()
            if not line:
                continue

            parts = line.split(',')
            user_id = int(parts[0])
            movie_id = int(parts[1])
            rating = float(parts[2])

            # Store the rating in both directions: per-user and per-movie views.
            users.setdefault(user_id, {})[movie_id] = rating
            movies.setdefault(movie_id, {})[user_id] = rating

    return users, movies


def cosine_similarity(movie1_ratings, movie2_ratings):
    """
    OBJ: Compute cosine similarity between two movies based on user ratings
    dict, dict -> float

    Only users who rated BOTH movies contribute: the dot product and both
    magnitudes are computed over the co-rated users alone. Returns 0.0 when
    the movies share no raters or either co-rated vector is all zeros.
    """
    # Cosine similarity is symmetric, so walk the smaller mapping and probe
    # the larger one instead of materialising and intersecting two key sets.
    if len(movie1_ratings) <= len(movie2_ratings):
        smaller, larger = movie1_ratings, movie2_ratings
    else:
        smaller, larger = movie2_ratings, movie1_ratings

    dot_product = 0.0
    sum_sq_small = 0.0
    sum_sq_large = 0.0

    for user, r_small in smaller.items():
        if user not in larger:
            continue
        r_large = larger[user]
        dot_product += r_small * r_large
        sum_sq_small += r_small ** 2
        sum_sq_large += r_large ** 2

    magnitude_small = math.sqrt(sum_sq_small)
    magnitude_large = math.sqrt(sum_sq_large)

    # Also covers "no common users": both sums stay 0.0.
    if magnitude_small == 0 or magnitude_large == 0:
        return 0.0

    return dot_product / (magnitude_small * magnitude_large)


def compute_similarities_for_chunk(args):
    """
    OBJ: Worker that scores the movie pairs owned by one chunk of rows
    tuple -> list

    args is (movie_ids, start_idx, end_idx, movies, similarity_threshold).
    For each row i in [start_idx, end_idx), every later movie in movie_ids is
    paired with movie_ids[i]; pairs whose similarity is >= the threshold are
    returned as (movie1, movie2, similarity) in row-major order.
    """
    ids, first_row, stop_row, ratings_by_movie, min_sim = args
    scored_pairs = []

    for row in range(first_row, stop_row):
        left = ids[row]
        # Upper triangle only: each unordered pair is scored exactly once.
        candidates = (
            (left, right, cosine_similarity(ratings_by_movie[left], ratings_by_movie[right]))
            for right in ids[row + 1:]
        )
        scored_pairs.extend(pair for pair in candidates if pair[2] >= min_sim)

    return scored_pairs


def _balanced_row_ranges(n, parts):
    """
    OBJ: Split rows 0..n-1 of the upper-triangular pair matrix into at most
         `parts` contiguous, non-empty [start, end) ranges holding roughly
         equal numbers of pairs (row i owns n - 1 - i pairs)
    int, int -> list
    """
    if n == 0:
        return []
    target = (n * (n - 1) // 2) / parts
    bounds = [0]
    done = 0
    # Stop before the last row so every range stays non-empty; that row owns
    # no pairs and simply joins the final range.
    for row in range(n - 1):
        done += n - 1 - row
        if len(bounds) < parts and done >= target * len(bounds):
            bounds.append(row + 1)
    bounds.append(n)
    return list(zip(bounds, bounds[1:]))


def compute_movie_similarities(movies, similarity_threshold=0.0, num_threads=None):
    """
    OBJ: Compute similarities between all movie pairs using multiprocessing
    dict, float, int -> dict

    movies: movie_id -> {user_id: rating}
    similarity_threshold: only pairs with similarity >= this value are kept
    num_threads: number of worker processes (default: cpu_count())

    Returns movie_id -> [(other_movie_id, similarity), ...] sorted by
    similarity, descending. Movies with no pair meeting the threshold are
    absent from the result.

    Raises ValueError if num_threads < 1.
    """
    if num_threads is None:
        num_threads = cpu_count()
    if num_threads < 1:
        raise ValueError(f"num_threads must be at least 1, got {num_threads}")

    movie_ids = list(movies.keys())
    similarities = {}

    print(f"Computing similarities for {len(movie_ids)} movies using {num_threads} threads...")

    # Row i pairs with all later movies, so early rows carry far more work.
    # Cut the rows so each chunk holds about the same number of pairs instead
    # of the same number of rows. Chunks stay contiguous and ordered, so the
    # merged output order matches a sequential pass.
    chunks = [
        (movie_ids, start, end, movies, similarity_threshold)
        for start, end in _balanced_row_ranges(len(movie_ids), num_threads)
    ]

    # Process chunks in parallel; never start more workers than chunks.
    chunk_results = []
    if chunks:
        with Pool(min(num_threads, len(chunks))) as pool:
            chunk_results = pool.map(compute_similarities_for_chunk, chunks)

    # Merge results: record each pair under both movies.
    for chunk_result in chunk_results:
        for movie1, movie2, sim in chunk_result:
            similarities.setdefault(movie1, []).append((movie2, sim))
            similarities.setdefault(movie2, []).append((movie1, sim))

    # Sort each movie's similar movies by similarity (descending).
    for neighbours in similarities.values():
        neighbours.sort(key=lambda pair: pair[1], reverse=True)

    print("Similarity computation complete!")
    return similarities


def predict_rating(user_id, movie_id, users, similarities, k=10):
    """
    OBJ: Predict rating for a user-movie pair using Item-Item collaborative filtering
    int, int, dict, dict, int -> float or None

    The neighbourhood is fixed per movie: the first k entries of
    similarities[movie_id]. compute_movie_similarities sorts these by
    similarity, descending. Of those k neighbours, only the ones the user
    rated contribute, so fewer than k (possibly zero) ratings may be used:
        prediction = sum(sim * rating) / sum(|sim|)

    Returns None when movie_id has no neighbour list, the user is unknown,
    or the user rated none of the k neighbours.
    """
    if movie_id not in similarities:
        return None

    # An unknown user has nothing to draw on; return None rather than
    # raising KeyError, in line with the documented return type.
    user_ratings = users.get(user_id)
    if not user_ratings:
        return None

    weighted_sum = 0.0
    similarity_sum = 0.0

    # Top-k neighbours overall, then keep only those the user has rated.
    for other_movie, sim in similarities[movie_id][:k]:
        if other_movie in user_ratings:
            weighted_sum += sim * user_ratings[other_movie]
            # |sim| keeps the normaliser positive if negative similarities appear.
            similarity_sum += abs(sim)

    if similarity_sum == 0:
        return None

    return weighted_sum / similarity_sum


def recommend_for_user(user_id, users, movies, similarities, k=10, top_n=5):
    """
    OBJ: Rank the movies a user has not rated by predicted rating
    int, dict, dict, dict, int, int -> list

    Returns up to top_n (movie_id, predicted_rating) tuples, highest
    prediction first. Movies for which no prediction can be made are left
    out. An unknown user gets a message and an empty list.
    """
    if user_id not in users:
        print(f"User {user_id} not found in dataset")
        return []

    rated_by_user = users[user_id]
    candidates = set(movies.keys()) - set(rated_by_user.keys())

    scored = []
    for candidate in candidates:
        estimate = predict_rating(user_id, candidate, users, similarities, k)
        if estimate is None:
            continue
        scored.append((candidate, estimate))

    ranked = sorted(scored, key=lambda pair: pair[1], reverse=True)
    return ranked[:top_n]


def generate_recommendations(users, movies, similarities, similarity_threshold, k=10):
    """
    OBJ: Generate recommendations for all users
    dict, dict, dict, float, int -> list

    For every user, keeps only the single highest-predicted unrated movie.
    Returns (user_id, movie_id, predicted_rating) triples in the order the
    users dict iterates. Users with no predictable movie are omitted.

    similarity_threshold is not used here; the threshold takes effect when
    the similarities are computed. The parameter is kept so existing call
    sites keep working.
    """
    recommendations = []
    total_users = len(users)

    print(f"Generating recommendations for {total_users} users...")

    for i, user_id in enumerate(users):
        if i % 100 == 0:
            print(f"Processed {i}/{total_users} users")

        # Reuse recommend_for_user but only take the best recommendation.
        user_recs = recommend_for_user(user_id, users, movies, similarities, k, top_n=1)
        if user_recs:
            movie_id, pred_rating = user_recs[0]
            recommendations.append((user_id, movie_id, pred_rating))

    print("Recommendations complete!")
    return recommendations


def main(argv=None):
    """
    OBJ: Command-line entry point: load ratings, compute movie similarities,
         and print recommendations with timing information
    list -> None

    argv defaults to sys.argv. Exits with status 1 and a usage message when
    no ratings file is given.
    """
    if argv is None:
        argv = sys.argv

    if len(argv) < 2:
        print("Usage: ./collab_filter.py <ratings_file> [similarity_threshold] [user_id] [num_threads]")
        print("*If user_id is provided, shows recommendations only for that user")
        print("*num_threads: number of threads to use (default: all CPU cores)")
        sys.exit(1)

    ratings_file = argv[1]
    similarity_threshold = float(argv[2]) if len(argv) > 2 else 0.0
    single_user_id = int(argv[3]) if len(argv) > 3 else None
    num_threads = int(argv[4]) if len(argv) > 4 else None

    # perf_counter is monotonic, so durations cannot be skewed by clock changes.
    # Load data
    print("Loading ratings...")
    start_time = time.perf_counter()
    users, movies = load_ratings(ratings_file)
    load_time = time.perf_counter() - start_time
    print(f"Loaded {len(users)} users and {len(movies)} movies in {load_time:.2f} seconds")

    # Compute similarities
    print("\nComputing similarities...")
    start_time = time.perf_counter()
    similarities = compute_movie_similarities(movies, similarity_threshold, num_threads)
    sim_time = time.perf_counter() - start_time
    print(f"Similarity computation took {sim_time:.2f} seconds")

    # Generate recommendations
    print("\nGenerating recommendations...")
    start_time = time.perf_counter()
    if single_user_id is not None:
        # Single user mode
        print(f"\nTop recommendations for user {single_user_id}:")
        recommendations = recommend_for_user(single_user_id, users, movies, similarities)
        for movie_id, rating in recommendations:
            print(f"{single_user_id} {movie_id} {rating:.1f}")
    else:
        # All users mode
        recommendations = generate_recommendations(users, movies, similarities, similarity_threshold)
        for user_id, movie_id, rating in recommendations:
            print(f"{user_id} {movie_id} {rating:.1f}")
    rec_time = time.perf_counter() - start_time
    print(f"\nRecommendation generation took {rec_time:.2f} seconds")

    total_time = load_time + sim_time + rec_time
    print(f"\n{'='*50}")
    print(f"TOTAL EXECUTION TIME: {total_time:.2f} seconds")
    print(f"{'='*50}")


# Required for multiprocessing: with the "spawn"/"forkserver" start methods,
# Pool workers re-import the main module. Without this guard each worker
# would rerun the whole script and try to start its own Pool.
if __name__ == "__main__":
    main()